{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.9.1 64-bit",
   "metadata": {
    "interpreter": {
     "hash": "ae7dbc91df3af0185ffb34e3fe12035cb4fadd9b684f5beca1182a2e24aafe6e"
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "sys.path.append(\"Dataset\")\n",
    "from mnist_mod import load_mnist\n",
    "from CommonFun import forward,sigmoid,softmax"
   ]
  },
  {
   "source": [
    "## Load mnist"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "(x_train, t_train), (x_test, t_test) = load_mnist(normalize=False)"
   ]
  },
  {
   "source": [
    "## show mnist size"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "x_train.shape=\t (60000, 784)\nt_train.shape=\t (60000,)\nx_test.shape=\t (10000, 784)\nt_test.shape=\t (10000,)\n"
     ]
    }
   ],
   "source": [
    "print('x_train.shape=\\t',x_train.shape)\n",
    "print('t_train.shape=\\t',t_train.shape)\n",
    "print('x_test.shape=\\t',x_test.shape)\n",
    "print('t_test.shape=\\t',t_test.shape)"
   ]
  },
  {
   "source": [
    "## Show first image in training dataset"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "5\n(784,)->(28, 28)\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "<Figure size 432x288 with 1 Axes>",
      "image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\r\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n<!-- Created with matplotlib (https://matplotlib.org/) -->\r\n<svg height=\"248.518125pt\" version=\"1.1\" viewBox=\"0 0 251.565 248.518125\" width=\"251.565pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n <metadata>\r\n  <rdf:RDF xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\r\n   <cc:Work>\r\n    <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\r\n    <dc:date>2021-03-22T17:23:57.546232</dc:date>\r\n    <dc:format>image/svg+xml</dc:format>\r\n    <dc:creator>\r\n     <cc:Agent>\r\n      <dc:title>Matplotlib v3.3.4, https://matplotlib.org/</dc:title>\r\n     </cc:Agent>\r\n    </dc:creator>\r\n   </cc:Work>\r\n  </rdf:RDF>\r\n </metadata>\r\n <defs>\r\n  <style type=\"text/css\">*{stroke-linecap:butt;stroke-linejoin:round;}</style>\r\n </defs>\r\n <g id=\"figure_1\">\r\n  <g id=\"patch_1\">\r\n   <path d=\"M 0 248.518125 \r\nL 251.565 248.518125 \r\nL 251.565 0 \r\nL 0 0 \r\nz\r\n\" style=\"fill:none;\"/>\r\n  </g>\r\n  <g id=\"axes_1\">\r\n   <g id=\"patch_2\">\r\n    <path d=\"M 26.925 224.64 \r\nL 244.365 224.64 \r\nL 244.365 7.2 \r\nL 26.925 7.2 \r\nz\r\n\" style=\"fill:#ffffff;\"/>\r\n   </g>\r\n   <g clip-path=\"url(#p5428909d23)\">\r\n    <image height=\"218\" id=\"image944d7a7f5b\" transform=\"scale(1 -1)translate(0 -218)\" width=\"218\" x=\"26.925\" xlink:href=\"data:image/png;base64,\r\niVBORw0KGgoAAAANSUhEUgAAANoAAADaCAYAAADAHVzbAAAGHklEQVR4nO3dvUvW/x7HcX9HaehWbCgIImwwKiLoRogoIiiIguxmcGhtkppagqDF+EHUIDVIQ+B/ULQUQtYQSNLdEARNETgmlEVhdqZz4MC53tKlvvylj8f64uv3O/jkA9eXS/9qaWn51QLMq38t9APAUiA0CBAaBAgNAoQGAUKDAKFBgNAgQGgQIDQIEBoECA0ChAYBQoMAoUGA0CBAaBAgNAgQGgQIDQKEBgFCg4C2hbz533//Xe6XLl2at3u/ffu23B88eFDuU1NT5X7jxo2G28TERHkti48TDQKEBgFCgwChQYDQIEBoECA0CPirZQH/bVN3d3e5z/Qebe/evQ23DRs2NPVMc+Xr168Nt4GBgfLaa9eulfvk5GRTz8TCcaJBgNAgQGgQIDQIEBoECA0ChAYBC/oebbY6OjoaboODg+W1O3fuLPfOzs5mHmlOPHv2rNyr77q1tLS0PHz4sNy/ffv228/E7DjRIEBoECA0CBAaBAgNAoQGAUKDgD/6PdpsrFu3rty3bt1a7rdu3Sr3LVu2/PYzzZXR0dFyv379esPt3r175bXT09NNPdNS50SDAKFBgNAgQGgQIDQIEBoECA0Clux7tNlav359uff29jbc+vr6yms3bdrUzCPNibGxsXLv7+8v9/v378/l4ywaTjQIEBoECA0ChAYBQoMAoUGAj/cXQFdXV7nP9PH/qVOnyn2mVw+z8fPnz3IfHh4u92PHjs3l4/wxnGgQIDQIEBoECA0ChAYBQoMAoUGA92h/oB07dpT7mTNnyn3Pnj0NtyNHjjT1TP/x5s2bct+1a1fDbTH/KTsnGgQIDQKEBgFCgwChQYDQIEBoEOA9Gv/j+/fv5d7W1lbuU1NT5X706NGG28jISHntn8yJBgFCgwChQYDQIEBoECA0CBAaBNQvRfgjtbe3l/uJEycabq2trbO699OnT8t9Mb8rqzjRIEBoECA0CBAaBAgNAoQGAT7e/wNt37693G/evFnuhw8fbvreg4OD5d7f39/0z17MnGgQIDQIEBoECA0ChAYBQoMAoUGA92j/QD09PeV+9+7dcl+1alXT9758+XK5Dw0Nlfv4+HjT917MnGgQIDQIEBoECA0ChAYBQoMAoUGAf9u0ADZv3lzuL168KPeJiYlyf/z4cbmPjY013G7fvl1e++uXX5dmONEgQGgQIDQIEBoECA0ChAYBQoMA30ebJytWrGi43blzp7x25cqV5X727Nlyf/ToUbmT50SDAKFBgNAgQGgQIDQIEBoECA0CvEebJ1evXm24HTx4sLz2yZMn5T48PNzMI7GAnGgQIDQIEBoECA0ChAYBQoMAH+83sHr16nL//Plzua9Zs6bpe8/0NZrp6emmfzYLw4kGAUKDAKFBgNAgQGgQIDQIEBoELNl/23Ty5MlyP378eLm/fPmy3AcGBn73kf7r1atX5X7gwIFyn5ycLPdt27Y13C5evFhee/78+XLn/3OiQYDQIEBoECA0CBAaBAgNAoQGAYv2PVpHR0e5j46OlntnZ+dcPs6cmunZJyYmyv3QoUMNtx8/fpTXzuZ7dkuZEw0ChAYBQoMAoUGA0CBAaBAgNAhYtH/XcePGjeW+du3a0JPMve7u7nn72W1t9a/EuXPnyv3Lly9N33t8fLzcP336VO7v3r1r+t7zzYkGAUKDAKFBgNAgQGgQIDQIEBoELNrvo81kpvdsy5YtK/d9+/aV+/79+xtu7e3t5bWnT58u94X08ePHcn/+/Hm59/T0NNxm+nuUr1+/LvcrV66U+8jISLnPJycaBAgNAoQGAUKDAKFBgNAgYNF+TWYmHz58mNX179+/L/ehoaGGW2tra3ntfP9Jt76+vobb8uXLy2u7urrK/cKFC+Ve/Tm73t7e8trdu3eX+8GDB8vdx/uwyAkNAoQGAUKDAKFBgNAgQGgQsGS/JgNJTjQIEBoECA0ChAYBQoMAoUGA0CBAaBAgNAgQGgQIDQKEBgFCgwChQYDQIEBoECA0CBAaBAgNAoQGAUKDAKFBgNAgQGgQIDQIEBoECA0ChAYBQoMAoUGA0CBAaBAgNAgQGgQIDQKEBgFCgwChQcC/AXA/6ePSIzy5AAAAAElFTkSuQmCC\" y=\"-6.64\"/>\r\n   </g>\r\n   <g id=\"matplotlib.axis_1\">\r\n    <g id=\"xtick_1\">\r\n     <g id=\"line2d_1\">\r\n      <defs>\r\n       <path d=\"M 0 0 \r\nL 0 3.5 \r\n\" id=\"m7a49a789cd\" style=\"stroke:#000000;stroke-width:0.8;\"/>\r\n      </defs>\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"30.807857\" xlink:href=\"#m7a49a789cd\" y=\"224.64\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_1\">\r\n      <!-- 0 -->\r\n      <g transform=\"translate(27.626607 239.238437)scale(0.1 -0.1)\">\r\n       <defs>\r\n        <path d=\"M 31.78125 66.40625 \r\nQ 24.171875 66.40625 20.328125 58.90625 \r\nQ 16.5 51.421875 16.5 36.375 \r\nQ 16.5 21.390625 20.328125 13.890625 \r\nQ 24.171875 6.390625 31.78125 6.390625 \r\nQ 39.453125 6.390625 43.28125 13.890625 \r\nQ 47.125 21.390625 47.125 36.375 \r\nQ 47.125 51.421875 43.28125 58.90625 \r\nQ 39.453125 66.40625 31.78125 66.40625 \r\nz\r\nM 31.78125 74.21875 \r\nQ 44.046875 74.21875 50.515625 64.515625 \r\nQ 56.984375 54.828125 56.984375 36.375 \r\nQ 56.984375 17.96875 50.515625 8.265625 \r\nQ 44.046875 -1.421875 31.78125 -1.421875 \r\nQ 19.53125 -1.421875 13.0625 8.265625 \r\nQ 6.59375 17.96875 6.59375 36.375 \r\nQ 6.59375 54.828125 13.0625 64.515625 \r\nQ 19.53125 74.21875 31.78125 74.21875 \r\nz\r\n\" id=\"DejaVuSans-48\"/>\r\n       </defs>\r\n       <use xlink:href=\"#DejaVuSans-48\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"xtick_2\">\r\n     <g id=\"line2d_2\">\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"69.636429\" xlink:href=\"#m7a49a789cd\" y=\"224.64\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_2\">\r\n      <!-- 5 -->\r\n      <g transform=\"translate(66.455179 239.238437)scale(0.1 -0.1)\">\r\n       <defs>\r\n        <path d=\"M 10.796875 72.90625 \r\nL 49.515625 72.90625 \r\nL 49.515625 64.59375 \r\nL 19.828125 64.59375 \r\nL 19.828125 46.734375 \r\nQ 21.96875 47.46875 24.109375 47.828125 \r\nQ 26.265625 48.1875 28.421875 48.1875 \r\nQ 40.625 48.1875 47.75 41.5 \r\nQ 54.890625 34.8125 54.890625 23.390625 \r\nQ 54.890625 11.625 47.5625 5.09375 \r\nQ 40.234375 -1.421875 26.90625 -1.421875 \r\nQ 22.3125 -1.421875 17.546875 -0.640625 \r\nQ 12.796875 0.140625 7.71875 1.703125 \r\nL 7.71875 11.625 \r\nQ 12.109375 9.234375 16.796875 8.0625 \r\nQ 21.484375 6.890625 26.703125 6.890625 \r\nQ 35.15625 6.890625 40.078125 11.328125 \r\nQ 45.015625 15.765625 45.015625 23.390625 \r\nQ 45.015625 31 40.078125 35.4375 \r\nQ 35.15625 39.890625 26.703125 39.890625 \r\nQ 22.75 39.890625 18.8125 39.015625 \r\nQ 14.890625 38.140625 10.796875 36.28125 \r\nz\r\n\" id=\"DejaVuSans-53\"/>\r\n       </defs>\r\n       <use xlink:href=\"#DejaVuSans-53\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"xtick_3\">\r\n     <g id=\"line2d_3\">\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"108.465\" xlink:href=\"#m7a49a789cd\" y=\"224.64\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_3\">\r\n      <!-- 10 -->\r\n      <g transform=\"translate(102.1025 239.238437)scale(0.1 -0.1)\">\r\n       <defs>\r\n        <path d=\"M 12.40625 8.296875 \r\nL 28.515625 8.296875 \r\nL 28.515625 63.921875 \r\nL 10.984375 60.40625 \r\nL 10.984375 69.390625 \r\nL 28.421875 72.90625 \r\nL 38.28125 72.90625 \r\nL 38.28125 8.296875 \r\nL 54.390625 8.296875 \r\nL 54.390625 0 \r\nL 12.40625 0 \r\nz\r\n\" id=\"DejaVuSans-49\"/>\r\n       </defs>\r\n       <use xlink:href=\"#DejaVuSans-49\"/>\r\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"xtick_4\">\r\n     <g id=\"line2d_4\">\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"147.293571\" xlink:href=\"#m7a49a789cd\" y=\"224.64\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_4\">\r\n      <!-- 15 -->\r\n      <g transform=\"translate(140.931071 239.238437)scale(0.1 -0.1)\">\r\n       <use xlink:href=\"#DejaVuSans-49\"/>\r\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"xtick_5\">\r\n     <g id=\"line2d_5\">\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"186.122143\" xlink:href=\"#m7a49a789cd\" y=\"224.64\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_5\">\r\n      <!-- 20 -->\r\n      <g transform=\"translate(179.759643 239.238437)scale(0.1 -0.1)\">\r\n       <defs>\r\n        <path d=\"M 19.1875 8.296875 \r\nL 53.609375 8.296875 \r\nL 53.609375 0 \r\nL 7.328125 0 \r\nL 7.328125 8.296875 \r\nQ 12.9375 14.109375 22.625 23.890625 \r\nQ 32.328125 33.6875 34.8125 36.53125 \r\nQ 39.546875 41.84375 41.421875 45.53125 \r\nQ 43.3125 49.21875 43.3125 52.78125 \r\nQ 43.3125 58.59375 39.234375 62.25 \r\nQ 35.15625 65.921875 28.609375 65.921875 \r\nQ 23.96875 65.921875 18.8125 64.3125 \r\nQ 13.671875 62.703125 7.8125 59.421875 \r\nL 7.8125 69.390625 \r\nQ 13.765625 71.78125 18.9375 73 \r\nQ 24.125 74.21875 28.421875 74.21875 \r\nQ 39.75 74.21875 46.484375 68.546875 \r\nQ 53.21875 62.890625 53.21875 53.421875 \r\nQ 53.21875 48.921875 51.53125 44.890625 \r\nQ 49.859375 40.875 45.40625 35.40625 \r\nQ 44.1875 33.984375 37.640625 27.21875 \r\nQ 31.109375 20.453125 19.1875 8.296875 \r\nz\r\n\" id=\"DejaVuSans-50\"/>\r\n       </defs>\r\n       <use xlink:href=\"#DejaVuSans-50\"/>\r\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"xtick_6\">\r\n     <g id=\"line2d_6\">\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"224.950714\" xlink:href=\"#m7a49a789cd\" y=\"224.64\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_6\">\r\n      <!-- 25 -->\r\n      <g transform=\"translate(218.588214 239.238437)scale(0.1 -0.1)\">\r\n       <use xlink:href=\"#DejaVuSans-50\"/>\r\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n   </g>\r\n   <g id=\"matplotlib.axis_2\">\r\n    <g id=\"ytick_1\">\r\n     <g id=\"line2d_7\">\r\n      <defs>\r\n       <path d=\"M 0 0 \r\nL -3.5 0 \r\n\" id=\"mbbb93abec0\" style=\"stroke:#000000;stroke-width:0.8;\"/>\r\n      </defs>\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#mbbb93abec0\" y=\"11.082857\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_7\">\r\n      <!-- 0 -->\r\n      <g transform=\"translate(13.5625 14.882076)scale(0.1 -0.1)\">\r\n       <use xlink:href=\"#DejaVuSans-48\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"ytick_2\">\r\n     <g id=\"line2d_8\">\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#mbbb93abec0\" y=\"49.911429\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_8\">\r\n      <!-- 5 -->\r\n      <g transform=\"translate(13.5625 53.710647)scale(0.1 -0.1)\">\r\n       <use xlink:href=\"#DejaVuSans-53\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"ytick_3\">\r\n     <g id=\"line2d_9\">\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#mbbb93abec0\" y=\"88.74\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_9\">\r\n      <!-- 10 -->\r\n      <g transform=\"translate(7.2 92.539219)scale(0.1 -0.1)\">\r\n       <use xlink:href=\"#DejaVuSans-49\"/>\r\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"ytick_4\">\r\n     <g id=\"line2d_10\">\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#mbbb93abec0\" y=\"127.568571\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_10\">\r\n      <!-- 15 -->\r\n      <g transform=\"translate(7.2 131.36779)scale(0.1 -0.1)\">\r\n       <use xlink:href=\"#DejaVuSans-49\"/>\r\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"ytick_5\">\r\n     <g id=\"line2d_11\">\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#mbbb93abec0\" y=\"166.397143\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_11\">\r\n      <!-- 20 -->\r\n      <g transform=\"translate(7.2 170.196362)scale(0.1 -0.1)\">\r\n       <use xlink:href=\"#DejaVuSans-50\"/>\r\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"ytick_6\">\r\n     <g id=\"line2d_12\">\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"26.925\" xlink:href=\"#mbbb93abec0\" y=\"205.225714\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_12\">\r\n      <!-- 25 -->\r\n      <g transform=\"translate(7.2 209.024933)scale(0.1 -0.1)\">\r\n       <use xlink:href=\"#DejaVuSans-50\"/>\r\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n   </g>\r\n   <g id=\"patch_3\">\r\n    <path d=\"M 26.925 224.64 \r\nL 26.925 7.2 \r\n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\r\n   </g>\r\n   <g id=\"patch_4\">\r\n    <path d=\"M 244.365 224.64 \r\nL 244.365 7.2 \r\n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\r\n   </g>\r\n   <g id=\"patch_5\">\r\n    <path d=\"M 26.925 224.64 \r\nL 244.365 224.64 \r\n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\r\n   </g>\r\n   <g id=\"patch_6\">\r\n    <path d=\"M 26.925 7.2 \r\nL 244.365 7.2 \r\n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\r\n   </g>\r\n  </g>\r\n </g>\r\n <defs>\r\n  <clipPath id=\"p5428909d23\">\r\n   <rect height=\"217.44\" width=\"217.44\" x=\"26.925\" y=\"7.2\"/>\r\n  </clipPath>\r\n </defs>\r\n</svg>\r\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAN8klEQVR4nO3df6jVdZ7H8ddrbfojxzI39iZOrWOEUdE6i9nSyjYRTj8o7FYMIzQ0JDl/JDSwyIb7xxSLIVu6rBSDDtXYMus0UJHFMNVm5S6BdDMrs21qoxjlphtmmv1a9b1/3K9xp+75nOs53/PD+34+4HDO+b7P93zffPHl99f53o8jQgAmvj/rdQMAuoOwA0kQdiAJwg4kQdiBJE7o5sJsc+of6LCI8FjT29qy277C9lu237F9ezvfBaCz3Op1dtuTJP1B0gJJOyW9JGlRROwozMOWHeiwTmzZ50l6JyLejYgvJf1G0sI2vg9AB7UT9hmS/jjq/c5q2p+wvcT2kO2hNpYFoE0dP0EXEeskrZPYjQd6qZ0t+y5JZ4x6/51qGoA+1E7YX5J0tu3v2j5R0o8kbaynLQB1a3k3PiIO2V4q6SlJkyQ9EBFv1NYZgFq1fOmtpYVxzA50XEd+VAPg+EHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEi0P2Yzjw6RJk4r1U045paPLX7p0acPaSSedVJx39uzZxfqtt95arN9zzz0Na4sWLSrO+/nnnxfrK1euLNbvvPPOYr0X2gq77fckHZB0WNKhiJhbR1MA6lfHlv3SiPiwhu8B0EEcswNJtBv2kPS07ZdtLxnrA7aX2B6yPdTmsgC0od3d+PkRscv2X0h6xvZ/R8Tm0R+IiHWS1kmS7WhzeQBa1NaWPSJ2Vc97JD0maV4dTQGoX8thtz3Z9pSjryX9QNL2uhoDUK92duMHJD1m++j3/HtE/L6WriaYM888s1g/8cQTi/WLL764WJ8/f37D2tSpU4vzXn/99cV6L+3cubNYX7NmTbE+ODjYsHbgwIHivK+++mqx/sILLxTr/ajlsEfEu5L+qsZeAHQQl96AJAg7kARhB5Ig7EAShB1IwhHd+1HbRP0F3Zw5c4r1TZs2Feudvs20Xx05cqRYv/nmm4v1Tz75pOVlDw8PF+sfffRRsf7WW2+1vOxOiwiPNZ0tO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwXX2GkybNq1Y37JlS7E+a9asOtupVbPe9+3bV6xfeumlDWtffvllcd6svz9oF9fZgeQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJhmyuwd69e4v1ZcuWFetXX311sf7KK68U683+pHLJtm3bivUFCxYU6wcPHizWzzvvvIa12267rTgv6sWWHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeS4H72PnDyyScX682GF167dm3D2uLFi4vz3njjjcX6hg0binX0n5bvZ7f9gO09trePmjbN9jO2366eT62zWQD1G89u/K8kXfG1abdLejYizpb0bPUeQB9rGvaI2Czp678HXShpffV6vaRr620LQN1a/W38QEQcHSzrA0kDjT5oe4mkJS0uB0BN2r4RJiKidOItItZJWidxgg7opVYvve22PV2Squc99bUEoBNaDftGSTdVr2+S9Hg97QDolKa78bY3SPq+pNNs75T0c0krJf3W9mJJ70v6YSebnOj279/f1vwff/xxy/PecsstxfrDDz9crDcbYx39o2nYI2JRg9JlNfcCoIP4uSyQBGEHkiDsQBKEHUiCsANJcIvrBDB58uSGtSeeeKI47yWXXFKsX3nllcX6008/Xayj+xiyGUiOsANJEHYgCcIOJEHYgSQIO5AEYQeS4Dr7BHfWWWcV61u3bi3W9+3bV6w/99xzxfrQ0FDD2n333Vect5v/NicSrrMDyRF2IAnCDiRB2IEkCDuQBGEHkiDsQBJcZ09ucHCwWH/wwQeL9SlTprS87OXLlxfrDz30ULE+PDxcrGfFdXYgOcIOJEHYgSQIO5AEYQeSIOxAEoQdSILr7Cg6//zzi/XVq1cX65dd1vpgv2vXri3WV6xYUazv2rWr5WUfz1q+zm77Adt7bG8fNe0O27tsb6seV9XZLID6jWc3/leSrhhj+r9ExJzq8bt62wJQt6Zhj4jNkvZ2oRcAHdTOCbqltl+rdvNPbfQh20tsD9lu/MfIAHRcq2H/haSzJM2RNCxpVaMPRsS6iJgbEXNbXBaAGrQU9ojYHRGHI+KIpF9KmldvWwDq1lLYbU8f9XZQ0vZGnwXQH5peZ7e9QdL3JZ0mabekn1fv50gKSe9J+mlENL25mOvsE8/UqVOL9WuuuaZhrdm98vaYl4u/smnTpmJ9wYIFxfpE1eg6+wnjmHHRGJPvb7sjAF3Fz2WBJAg7kARhB5Ig7EAShB1Igltc0TNffPFFsX7CCeWLRYcOHSrWL7/88oa1559/vjjv8Yw/JQ0kR9iBJAg7kARhB5Ig7EAShB1IgrADSTS96w25XXDBBcX6DTfcUKxfeOGFDWvNrqM3s2PHjmJ98+bNbX3/RMOWHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeS4Dr7BDd79uxifenSpcX6ddddV6yffvrpx9zTeB0+fLhYHx4u//XyI0eO1NnOcY8tO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwXX240Cza9mLFo010O6IZtfRZ86c2UpLtRgaGirWV6xYUaxv3LixznYmvKZbdttn2H7O9g7bb9i+rZo+zfYztt+unk/tfLsAWjWe3fhDkv4+Is6V9DeSbrV9rqTbJT0bEWdLerZ6D6BPNQ17RAxHxNbq9QFJb0qaIWmhpPXVx9ZLurZDPQKowTEds9ueKel7krZIGoiIoz9O/kDSQIN5lkha0kaPAGow7rPxtr8t6RFJP4uI/aNrMTI65JiDNkbEuoiYGxFz2+oUQFvGFXbb39JI0H8dEY9Wk3fbnl7Vp0va05kWAdSh6W68bUu6X9KbEbF6VGmjpJskrayeH+9IhxPAwMCYRzhfOffcc4v1e++9t1g/55xzjrmnumzZsqVYv/vuuxvWHn+8/E+GW1TrNZ5j9r+V9GNJr9veVk1brpGQ/9b2YknvS/phRzoEUIumYY+I/5I05uDuki6rtx0AncLPZYEkCDuQBGEHkiDsQBKEHUiCW1zHadq0aQ1ra9euLc47Z86cYn3WrFmttFSLF198sVhftWpVsf7UU08V65999tkx94TOYMsOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0mkuc5+0UUXFevLli0r1ufNm9ewNmPGjJZ6qsunn37asLZmzZrivHfddVexfvDgwZZ6Qv9hyw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSaS5zj44ONhWvR07duwo1p988sli/dChQ8V66Z7zffv2FedFHmzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJR0T5A/YZkh6SNCApJK2LiH+1fYekWyT9b/XR5RHxuybfVV4YgLZFxJijLo8n7NMlTY+IrbanSHpZ0rUaGY/9k4i4Z7xNEHag8xqFfTzjsw9LGq5eH7D9pqTe/mkWAMfsmI7Zbc+U9D1JW6pJS22/ZvsB26c2mGeJ7SHbQ+21CqAdTXfjv/qg/W1JL0haERGP2h6Q9KFGjuP/SSO7+jc3+Q5244EOa/mYXZJsf0vSk5KeiojVY9RnSnoyIs5v8j2EHeiwRmFvuhtv25Lul/Tm6KBXJ+6OGpS0vd0mAXTOeM7Gz5f0n5Jel3Skmrxc0iJJczSyG/+epJ9WJ/NK38WWHeiwtnbj60LYgc5reTcewMRA2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSKLbQzZ/KOn9Ue9Pq6b1o37trV/7kuitVXX29peNCl29n/0bC7eHImJuzxoo6Nfe+rUvid5a1a3e2I0HkiDsQBK9Dvu6Hi+/pF9769e+JHprVVd66+kxO4Du6fWWHUCXEHYgiZ6E3fYVtt+y/Y7t23vRQyO237P9uu1tvR6frhpDb4/t7aOmTbP9jO23q+cxx9jrUW932N5Vrbtttq/qUW9n2H7O9g7bb9i+rZre03VX6Ksr663rx+y2J0n6g6QFknZKeknSoojY0dVGGrD9nqS5EdHzH2DY/jtJn0h66OjQWrb/WdLeiFhZ/Ud5akT8Q5/0doeOcRjvDvXWaJjxn6iH667O4c9b0Yst+zxJ70TEuxHxpaTfSFrYgz76XkRslrT3a5MXSlpfvV6vkX8sXdegt74QEcMRsbV6fUDS0WHGe7ruCn11RS/CPkPSH0e936n+Gu89JD1t+2XbS3rdzBgGRg2z9YGkgV42M4amw3h309eGGe+bddfK8Oft4gTdN82PiL+WdKWkW6vd1b4UI8dg/XTt9BeSztLIGIDDklb1splqmPFHJP0sIvaPrvVy3Y3RV1fWWy/CvkvSGaPef6ea1hciYlf1vEfSYxo57Ognu4+OoFs97+lxP1+JiN0RcTgijkj6pXq47qphxh+R9OuIeLSa3PN1N1Zf3VpvvQj7S5LOtv1d2ydK+pGkjT3o4xtsT65OnMj2ZEk/UP8NRb1R0k3V65skPd7DXv5Evwzj3WiYcfV43fV8+POI6PpD0lUaOSP/P5L+sRc9NOhrlqRXq8cbve5N0gaN7Nb9n0bObSyW9OeSnpX0tqT/kDStj3r7N40M7f2aRoI1vUe9zdfILvprkrZVj6t6ve4KfXVlvfFzWSAJTtABSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBL/DyJ7caZa7LphAAAAAElFTkSuQmCC\n"
     },
     "metadata": {
      "needs_background": "light"
     }
    }
   ],
   "source": [
    "img,label=x_train[0],t_train[0]\n",
    "print(label)\n",
    "print(img.shape,end='->')\n",
    "img=img.reshape(28,28)\n",
    "print(img.shape)\n",
    "plt.imshow(Image.fromarray(np.uint8(img)), cmap='gray')\n",
    "plt.show()"
   ]
  },
  {
   "source": [
    "## load sample weight"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(os.path.join('Dataset',\"sample_weight.pkl\"), 'rb') as f:\n",
    "    network = pickle.load(f)\n",
    "def predict(network, x):\n",
    "    W1, W2, W3 = network['W1'], network['W2'], network['W3']\n",
    "    b1, b2, b3 = network['b1'], network['b2'], network['b3']\n",
    "    z1 = forward(x,W1,b1)\n",
    "    z2 = forward(z1,W2,b2)\n",
    "    z3 = forward(z2,W3,b3)\n",
    "    return z3"
   ]
  },
  {
   "source": [
    "## show w size"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "x:\t (10000, 784)\nx[0]:\t (784,)\nW1:\t (784, 50)\nW2:\t (50, 100)\nW3:\t (100, 10)\n"
     ]
    }
   ],
   "source": [
    "print('x:\\t',x_test.shape)\n",
    "print('x[0]:\\t',x_test[0].shape)\n",
    "print('W1:\\t',network['W1'].shape)\n",
    "print('W2:\\t',network['W2'].shape)\n",
    "print('W3:\\t',network['W3'].shape)"
   ]
  },
  {
   "source": [
    "## test without normalization"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Accuracy:0.9207\n"
     ]
    }
   ],
   "source": [
    "accuracy_cnt = 0\n",
    "for i in range(len(x_test)):\n",
    "    y = predict(network, x_test[i])\n",
    "    p= np.argmax(y)\n",
    "    if p == t_test[i]:\n",
    "        accuracy_cnt += 1\n",
    "print(\"Accuracy:\" + str(float(accuracy_cnt) / len(x_test)))"
   ]
  },
  {
   "source": [
    "## test after normalization"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Accuracy:0.9352\n"
     ]
    }
   ],
   "source": [
    "x,t= load_mnist(normalize=True)[1]\n",
    "accuracy_cnt = 0\n",
    "for i in range(len(x)):\n",
    "    y = predict(network, x[i])\n",
    "    p= np.argmax(y)\n",
    "    if p == t[i]:\n",
    "        accuracy_cnt += 1\n",
    "print(\"Accuracy:\" + str(float(accuracy_cnt) / len(x)))"
   ]
  },
  {
   "source": [
    "## batch"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Accuracy:0.9352\n"
     ]
    }
   ],
   "source": [
    "accuracy_cnt = 0\n",
    "batch_size=100\n",
    "for i in range(0,len(x),batch_size):\n",
    "    y = predict(network, x[i:i+batch_size])\n",
    "    p= np.argmax(y,axis=1)\n",
    "    accuracy_cnt += np.sum(p==t[i:i+batch_size])\n",
    "print(\"Accuracy:\" + str(float(accuracy_cnt) / len(x)))"
   ]
  }
 ]
}